{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "nerfies_eval_skeleton.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "m8pnxFaaaFFs"
      },
      "source": [
        "# @title Basic Imports\n",
        "\n",
        "from collections import defaultdict\n",
        "import collections\n",
        "\n",
        "import matplotlib\n",
        "from matplotlib import pyplot as plt\n",
        "import pandas as pd\n",
        "\n",
        "from IPython import display as ipd\n",
        "from PIL import Image\n",
        "import io\n",
        "import imageio\n",
        "import numpy as np\n",
        "import copy\n",
        "\n",
        "import gin\n",
        "gin.enter_interactive_mode()\n",
        "\n",
        "from six.moves import reload_module\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "from absl import logging\n",
        "from pprint import pprint\n",
        "\n",
        "\n",
        "def myprint(msg, *args, **kwargs):\n",
        "  \"\"\"Prints an absl-style log message to the cell output.\n",
        "\n",
        "  Args:\n",
        "    msg: The message, optionally containing %-style placeholders.\n",
        "    *args: Values substituted into `msg`.\n",
        "    **kwargs: Accepted for call compatibility and ignored.\n",
        "  \"\"\"\n",
        "  # Only interpolate when arguments were passed; otherwise a literal '%' in\n",
        "  # the message (e.g. '50% done') would raise during formatting.\n",
        "  print(msg % args if args else msg)\n",
        "\n",
        "\n",
        "# Route info/warn log calls through print so they show up in the notebook.\n",
        "logging.info = myprint\n",
        "logging.warn = myprint"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vblVxNPuZ1Cz"
      },
      "source": [
        "# @title Utility methods.\n",
        "\n",
        "import json\n",
        "from scipy.ndimage import morphology\n",
        "import math\n",
        "import cv2\n",
        "\n",
        "\n",
        "def load_image(x):\n",
        "  \"\"\"Reads an image file and returns it as float32 divided by 255.\n",
        "\n",
        "  Args:\n",
        "    x: Path of the image, forwarded to `image_utils.load_image`.\n",
        "\n",
        "  Returns:\n",
        "    The loaded array cast to float32 and divided by 255.\n",
        "  \"\"\"\n",
        "  try:\n",
        "    pixels = image_utils.load_image(x)\n",
        "  except:\n",
        "    # Report which file failed, then let the original error propagate.\n",
        "    print(f'Could not read image file {x}')\n",
        "    raise\n",
        "  return pixels.astype(np.float32) / 255.0\n",
        "\n",
        "\n",
        "def load_images(paths):\n",
        "  \"\"\"Loads every path in `paths` with `load_image` via `utils.parallel_map`.\"\"\"\n",
        "  images = utils.parallel_map(load_image, paths, show_pbar=False)\n",
        "  return images\n",
        "\n",
        "\n",
        "def load_mask(path):\n",
        "  \"\"\"Loads a mask file; masks are read exactly like images (`load_image`).\"\"\"\n",
        "  mask = load_image(path)\n",
        "  return mask\n",
        "\n",
        "\n",
        "def load_masks(paths):\n",
        "  \"\"\"Loads every path in `paths` with `load_mask` via `utils.parallel_map`.\"\"\"\n",
        "  masks = utils.parallel_map(load_mask, paths, show_pbar=False)\n",
        "  return masks\n",
        "\n",
        "\n",
        "def crop_image(image, left=0, right=0, top=0, bottom=0):\n",
        "  \"\"\"Crops (or zero-pads) the borders of an image.\n",
        "\n",
        "  A positive amount removes that many pixels from the given side; a negative\n",
        "  amount adds that many zero-valued pixels to that side instead.\n",
        "\n",
        "  Args:\n",
        "    image: Array whose first two axes are (height, width); any trailing axes\n",
        "      (e.g. channels) are left untouched.\n",
        "    left: Pixels to remove from the left edge (negative to pad).\n",
        "    right: Pixels to remove from the right edge (negative to pad).\n",
        "    top: Pixels to remove from the top edge (negative to pad).\n",
        "    bottom: Pixels to remove from the bottom edge (negative to pad).\n",
        "\n",
        "  Returns:\n",
        "    The cropped and/or padded image.\n",
        "  \"\"\"\n",
        "  pad_top, pad_bottom, pad_left, pad_right = (\n",
        "      max(0, -amount) for amount in (top, bottom, left, right))\n",
        "  if pad_top or pad_bottom or pad_left or pad_right:\n",
        "    # np.pad expects one (before, after) pair per axis; only the two spatial\n",
        "    # axes are padded.\n",
        "    pad_width = [(pad_top, pad_bottom), (pad_left, pad_right)]\n",
        "    pad_width += [(0, 0)] * (image.ndim - 2)\n",
        "    image = np.pad(image, pad_width=pad_width, mode='constant')\n",
        "\n",
        "  height, width = image.shape[:2]\n",
        "  # Use explicit end indices: a negative-zero stop (`x:-0`) would select\n",
        "  # nothing when a side is not cropped.\n",
        "  return image[max(0, top):height - max(0, bottom),\n",
        "               max(0, left):width - max(0, right)]\n",
        "\n",
        "\n",
        "def scale_image(image, scale, interpolation=None):\n",
        "  \"\"\"Resizes an image by a uniform factor with OpenCV.\n",
        "\n",
        "  Args:\n",
        "    image: The image array; its first two axes are (height, width).\n",
        "    scale: Multiplicative factor applied to both height and width.\n",
        "    interpolation: An OpenCV interpolation flag. When None, `cv2.INTER_AREA`\n",
        "      is used for scale < 1 and `cv2.INTER_LANCZOS4` otherwise.\n",
        "\n",
        "  Returns:\n",
        "    The resized image.\n",
        "  \"\"\"\n",
        "  if interpolation is None:\n",
        "    if scale < 1:\n",
        "      interpolation = cv2.INTER_AREA\n",
        "    else:\n",
        "      interpolation = cv2.INTER_LANCZOS4\n",
        "  src_height, src_width = image.shape[:2]\n",
        "  dst_height, dst_width = int(scale * src_height), int(scale * src_width)\n",
        "  # The target size is handed to cv2.resize as (width, height).\n",
        "  return cv2.resize(\n",
        "      image, (dst_width, dst_height), interpolation=interpolation)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z4mBp_dmzBNt"
      },
      "source": [
        "# @title Cropping Code\n",
        "\n",
        "def compute_nv_crop(image, scale=2.0):\n",
        "  \"\"\"Computes per-side crop amounts for `image`, in input-image pixels.\n",
        "\n",
        "  The computation mimics two successive center crops without resizing the\n",
        "  image: first a virtual rescale so the image covers the target size\n",
        "  (1024 * scale, 667 * scale) plus a margin followed by a center crop to that\n",
        "  target, then a virtual rescale by 0.51 followed by a center crop to\n",
        "  1024 x 667. Both crops are mapped back to input pixels and summed.\n",
        "\n",
        "  Args:\n",
        "    image: Array-like with a `shape` whose first two entries are\n",
        "      (height, width). Only the shape is read.\n",
        "    scale: Multiplier applied to the 1024 x 667 target of the first crop.\n",
        "\n",
        "  Returns:\n",
        "    A list [left, right, top, bottom] of integer crop amounts.\n",
        "  \"\"\"\n",
        "\n",
        "  def center_margins(size, target):\n",
        "    # Splits the excess over `target` into (before, after) margins; when\n",
        "    # there is no excess the margins are (0, 1).\n",
        "    if size > target:\n",
        "      half_excess = (size - target) / 2\n",
        "      return int(math.floor(half_excess)), int(math.ceil(half_excess))\n",
        "    return 0, 1\n",
        "\n",
        "  height, width = image.shape[:2]\n",
        "  target_height = int(1024 * scale)\n",
        "  target_width = int(667 * scale)\n",
        "\n",
        "  # From here on `scale` is the virtual resize factor that makes the image\n",
        "  # cover the target plus the margin in both dimensions.\n",
        "  margin = 100\n",
        "  scale = max((target_width + 1 + margin) / width,\n",
        "              (target_height + 1 + margin) / height)\n",
        "\n",
        "  # First center crop, computed at the virtually rescaled size.\n",
        "  scaled_height, scaled_width = int(height * scale), int(width * scale)\n",
        "  top, bottom = center_margins(scaled_height, target_height)\n",
        "  left, right = center_margins(scaled_width, target_width)\n",
        "  first_crop = np.array([left, right, top, bottom], np.float32)\n",
        "  first_crop /= scale\n",
        "\n",
        "  # Second center crop, computed after shrinking the cropped size by 0.51.\n",
        "  shrunk_height = int((scaled_height - top - bottom) * 0.51)\n",
        "  shrunk_width = int((scaled_width - right - left) * 0.51)\n",
        "  top, bottom = center_margins(shrunk_height, 1024)\n",
        "  left, right = center_margins(shrunk_width, 667)\n",
        "  second_crop = np.array([left, right, top, bottom], np.float32)\n",
        "  second_crop = second_crop / 0.51 / scale\n",
        "\n",
        "  return (first_crop + second_crop).astype(np.uint32).tolist()\n",
        "\n",
        "\n",
        "def _nv_crop(image, scale=2.0):\n",
        "  \"\"\"Applies the crop computed by `compute_nv_crop` to `image`.\"\"\"\n",
        "  left, right, top, bottom = compute_nv_crop(image, scale)\n",
        "  return crop_image(image, left, right, top, bottom)\n",
        "\n",
        "\n",
        "def nv_crop(image, scale=2.0):\n",
        "  \"\"\"Crops an image with `_nv_crop`, handling landscape inputs.\n",
        "\n",
        "  Images that are wider than tall are rotated by 90 degrees before cropping\n",
        "  and rotated back afterwards.\n",
        "\n",
        "  Args:\n",
        "    image: Array whose first two axes are (height, width).\n",
        "    scale: Forwarded to `_nv_crop`.\n",
        "\n",
        "  Returns:\n",
        "    The cropped image in the orientation of the input.\n",
        "  \"\"\"\n",
        "  height, width = image.shape[:2]\n",
        "  is_landscape = width > height\n",
        "  if is_landscape:\n",
        "    image = np.rot90(image)\n",
        "\n",
        "  cropped = _nv_crop(image, scale=scale)\n",
        "  return np.rot90(cropped, -1) if is_landscape else cropped\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VvKJPeTMzOyA"
      },
      "source": [
        "# @title Undistort Code\n",
        "\n",
        "def crop_camera_invalid_pixels(source_camera, target_camera):\n",
        "  \"\"\"Shrinks the target camera viewport so it only covers valid pixels.\n",
        "\n",
        "  Warping an image (e.g., undistorting it) can leave \"invalid\" pixels near\n",
        "  the border of the result. This function projects the pixel centers of the\n",
        "  source camera into the target camera, finds the rectangle inscribed in the\n",
        "  projected border, and scales the target image size and principal point so\n",
        "  that the viewport fits inside that rectangle.\n",
        "\n",
        "  Example:\n",
        "    # Cropping a camera `target_camera`.\n",
        "    target_cropped_camera = crop_camera_invalid_pixels(\n",
        "        source_camera, target_camera)\n",
        "    warped_cropped_image = warp_image_at_infinity(\n",
        "        source_camera, target_cropped_camera, image)\n",
        "\n",
        "  Args:\n",
        "    source_camera: The camera corresponding to the source image.\n",
        "    target_camera: The target camera on which to warp the source image.\n",
        "\n",
        "  Returns:\n",
        "    A copy of `target_camera` with a cropped viewport: its image size and\n",
        "      principal point are changed.\n",
        "  \"\"\"\n",
        "  # Cast a ray through every source pixel and project it into the target.\n",
        "  pixel_centers = source_camera.GetPixelCenters()\n",
        "  pixel_rays = source_camera.PixelsToRays(pixel_centers)\n",
        "  # Append a zero as fourth component of every ray before projecting.\n",
        "  pixel_dirs = np.insert(pixel_rays, 3, 0, axis=2)\n",
        "  projected, _ = target_camera.Project(pixel_dirs)\n",
        "\n",
        "  # Innermost extent of each projected border row / column.\n",
        "  top_max_y = max(projected[0, :, 1])\n",
        "  bottom_min_y = min(projected[-1, :, 1])\n",
        "  left_max_x = max(projected[:, 0, 0])\n",
        "  right_min_x = min(projected[:, -1, 0])\n",
        "\n",
        "  cx = target_camera.PrincipalPointX()\n",
        "  cy = target_camera.PrincipalPointY()\n",
        "  width = target_camera.ImageSizeX()\n",
        "  height = target_camera.ImageSizeY()\n",
        "\n",
        "  # Per axis, the more restrictive of the two opposite borders determines\n",
        "  # how much the viewport has to shrink.\n",
        "  shrink_x = max(cx / (cx - left_max_x),\n",
        "                 (width - 0.5 - cx) / (right_min_x - cx))\n",
        "  shrink_y = max(cy / (cy - top_max_y),\n",
        "                 (height - 0.5 - cy) / (bottom_min_y - cy))\n",
        "  scale_x = 1.0 / shrink_x\n",
        "  scale_y = 1.0 / shrink_y\n",
        "\n",
        "  new_width = int(scale_x * width)\n",
        "  new_height = int(scale_y * height)\n",
        "\n",
        "  # Keep the principal point at the same relative position in the new image.\n",
        "  cropped_camera = target_camera.Copy()\n",
        "  cropped_camera.SetPrincipalPoint(\n",
        "      cx * new_width / width, cy * new_height / height)\n",
        "  cropped_camera.SetImageSize(new_width, new_height)\n",
        "\n",
        "  return cropped_camera\n",
        "\n",
        "\n",
        "def undistort_camera(camera, crop_blank_pixels=False, set_ideal_pp=True):\n",
        "  \"\"\"Returns a copy of `camera` with its lens distortion zeroed out.\n",
        "\n",
        "  Args:\n",
        "    camera: The distorted camera. It is copied, not modified.\n",
        "    crop_blank_pixels: If True, the undistorted camera's viewport is cropped\n",
        "      with `crop_camera_invalid_pixels` so that it only covers valid pixels.\n",
        "    set_ideal_pp: If True, `SetIdealPrincipalPoint()` is called on the copy.\n",
        "\n",
        "  Returns:\n",
        "    The undistorted (and optionally cropped) camera.\n",
        "  \"\"\"\n",
        "  # Create a copy of the camera with no distortion.\n",
        "  undistorted_camera = camera.Copy()\n",
        "  undistorted_camera.SetRadialDistortion(0, 0, 0)\n",
        "  undistorted_camera.SetTangentialDistortion(0, 0)\n",
        "  if set_ideal_pp:\n",
        "    undistorted_camera.SetIdealPrincipalPoint()\n",
        "\n",
        "  if crop_blank_pixels:\n",
        "    # The cropping helper defined in this notebook is\n",
        "    # `crop_camera_invalid_pixels`.\n",
        "    return crop_camera_invalid_pixels(camera, undistorted_camera)\n",
        "\n",
        "  return undistorted_camera\n",
        "\n",
        "\n",
        "def undistort_image(image, camera, set_ideal_pp=True, crop_blank_pixels=False):\n",
        "  \"\"\"Warps `image` from `camera` to the undistorted version of that camera.\n",
        "\n",
        "  Args:\n",
        "    image: The image taken with `camera`.\n",
        "    camera: The source camera. A `cam.Camera` is first converted with\n",
        "      `to_sfm_camera()`; anything else is used as is.\n",
        "    set_ideal_pp: Forwarded to `undistort_camera`.\n",
        "    crop_blank_pixels: Forwarded to `undistort_camera`.\n",
        "\n",
        "  Returns:\n",
        "    A tuple (undistorted_image, undistorted_camera).\n",
        "  \"\"\"\n",
        "  sfm_camera = camera.to_sfm_camera() if isinstance(camera, cam.Camera) else camera\n",
        "  target_camera = undistort_camera(\n",
        "      sfm_camera, crop_blank_pixels, set_ideal_pp=set_ideal_pp)\n",
        "  warped = warp_image.warp_image_at_infinity(\n",
        "      sfm_camera, target_camera, image, mode='constant')\n",
        "  return warped, target_camera\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "S11sCgRn2RsT"
      },
      "source": [
        "# @title Metrics\n",
        "\n",
        "# Loaded once at module level; `compute_lpips` below calls this model.\n",
        "# NOTE(review): the '...' handle looks like a placeholder -- replace it with\n",
        "# the real TF-Hub module handle before running.\n",
        "lpips_model = hub.load('...')\n",
        "\n",
        "  \n",
        "def compute_ssim(target, pred):\n",
        "  \"\"\"Returns the multi-scale SSIM between `target` and `pred` (max_val=1.0).\"\"\"\n",
        "  target_tensor, pred_tensor = (\n",
        "      tf.convert_to_tensor(image) for image in (target, pred))\n",
        "  return tf.image.ssim_multiscale(target_tensor, pred_tensor, max_val=1.0)\n",
        "\n",
        "\n",
        "def compute_lpips(target, pred):\n",
        "  \"\"\"Evaluates `lpips_model` on a single (pred, target) image pair.\n",
        "\n",
        "  Both images are converted to tensors and given a leading batch axis of\n",
        "  size one; the prediction is passed to the model first.\n",
        "  \"\"\"\n",
        "  target_tensor = tf.convert_to_tensor(target)\n",
        "  pred_tensor = tf.convert_to_tensor(pred)\n",
        "  pred_batch = tf.expand_dims(pred_tensor, axis=0)\n",
        "  target_batch = tf.expand_dims(target_tensor, axis=0)\n",
        "  return lpips_model(pred_batch, target_batch)\n",
        "\n",
        "\n",
        "def compute_mse(x, y, sample_weight=None):\n",
        "  \"\"\"Calculates the (optionally weighted) mean squared error.\n",
        "\n",
        "  Args:\n",
        "    x: [..., 3] float32. RGB.\n",
        "    y: [..., 3] float32. RGB.\n",
        "    sample_weight: [...] float32. Optional weights (e.g. a mask), either\n",
        "      shaped like `x` or missing its trailing color axis.\n",
        "\n",
        "  Returns:\n",
        "    scalar float32. Average squared error across all entries in x, y. With\n",
        "    `sample_weight`, the weighted average sum(w * err**2) / sum(w).\n",
        "  \"\"\"\n",
        "  squared_error = (x - y) ** 2\n",
        "  if sample_weight is None:\n",
        "    return np.mean(squared_error)\n",
        "\n",
        "  sample_weight = np.asarray(sample_weight)\n",
        "  # A weight map without the color axis is shared by all colors.\n",
        "  if sample_weight.ndim == np.ndim(squared_error) - 1:\n",
        "    sample_weight = sample_weight[..., None]\n",
        "  sample_weight = np.broadcast_to(sample_weight, np.shape(squared_error))\n",
        "  # Weight the squared error rather than the raw difference, so the weights\n",
        "  # enter numerator and denominator with the same power and uniform weights\n",
        "  # reproduce the unweighted result.\n",
        "  numer = np.sum(sample_weight * squared_error)\n",
        "  denom = np.sum(sample_weight)\n",
        "  return numer / denom\n",
        "\n",
        "\n",
        "def mse_to_psnr(x):\n",
        "  \"\"\"Converts a mean squared error into PSNR: 10 * log10(1 / mse).\"\"\"\n",
        "  inverse_mse = 1 / x\n",
        "  return 10 * np.log10(inverse_mse)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fGkjB7HCaN7H"
      },
      "source": [
        "from collections import defaultdict\n",
        "\n",
        "# Root directories for experiment outputs and for datasets.\n",
        "# NOTE(review): both are constructed without a path -- presumably\n",
        "# placeholders to fill in before running.\n",
        "experiment_root = gpath.GPath()\n",
        "dataset_root = gpath.GPath()\n",
        "\n",
        "# Datasets to evaluate; the cells below iterate over its keys.\n",
        "dataset_names = {\n",
        "}\n",
        "\n",
        "# Per-dataset evaluation scale handed to `compute_experiment_metrics`;\n",
        "# datasets without an entry default to 1.\n",
        "eval_scales = defaultdict(lambda: 1)\n",
        "\n",
        "# Experiments to evaluate; iterated as (sweep_name, exp_name) pairs.\n",
        "experiments = {}\n",
        "\n",
        "# Default image scale; selects the `rgb/{scale}x` dataset directory.\n",
        "scale =    4#@param {type:\"number\"}\n",
        "# Not referenced by any other cell in this notebook.\n",
        "use_mask = False  # @param {type:\"boolean\"}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4X7CRsEVg3qJ"
      },
      "source": [
        "# Directory under which per-experiment metrics and image copies are cached.\n",
        "# NOTE(review): constructed without a path -- presumably a placeholder to\n",
        "# fill in before running.\n",
        "work_dir = gpath.GPath()\n",
        "work_dir.mkdir(exist_ok=True, parents=True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Nn33mMI1ttgD"
      },
      "source": [
        "\n",
        "def load_dataset_dict(name, scale=scale, use_images=True):\n",
        "  \"\"\"Loads split ids, validation cameras and optionally images of a dataset.\n",
        "\n",
        "  Args:\n",
        "    name: Dataset name, resolved relative to `dataset_root`.\n",
        "    scale: Image scale; selects the `rgb/{scale}x` directory and is passed\n",
        "      to the data source as `image_scale`.\n",
        "    use_images: Whether to also load the validation images.\n",
        "\n",
        "  Returns:\n",
        "    A dict with the keys 'name', 'scale', 'train_ids', 'val_ids' and\n",
        "    'val_cameras', plus 'val_rgbs' when `use_images` is True.\n",
        "  \"\"\"\n",
        "  print(f'Loading {name}...')\n",
        "  dataset_dir = dataset_root / name\n",
        "  image_path = dataset_dir / f'rgb/{scale}x'\n",
        "\n",
        "  with (dataset_dir / 'dataset.json').open('rt') as f:\n",
        "    dataset_json = json.load(f)\n",
        "\n",
        "  ds = datasets.NerfiesDataSource(dataset_dir, image_scale=scale, camera_type='proto')\n",
        "  val_ids = ds.val_ids\n",
        "\n",
        "  out = {\n",
        "      'name': name,\n",
        "      'scale': scale,\n",
        "      'train_ids': ds.train_ids,\n",
        "      'val_ids': val_ids,\n",
        "      'val_cameras': utils.parallel_map(ds.load_camera, val_ids),\n",
        "  }\n",
        "  if use_images:\n",
        "    # NOTE(review): the image ids are read from dataset.json while 'val_ids'\n",
        "    # comes from the data source; presumably the same list -- confirm.\n",
        "    out['val_rgbs'] = load_images(\n",
        "        [image_path / f'{x}.png' for x in dataset_json['val_ids']])\n",
        "  return out\n",
        "\n",
        "\n",
        "def load_experiment_images(dataset_name, sweep_name, exp_name, item_ids, seed=0, idx=-1):\n",
        "  \"\"\"Loads the validation renders of one experiment.\n",
        "\n",
        "  Args:\n",
        "    dataset_name: Dataset part of the experiment directory.\n",
        "    sweep_name: Sweep part of the experiment directory.\n",
        "    exp_name: Experiment part of the experiment directory.\n",
        "    item_ids: Ids of the renders to load (`rgb_<item_id>.png`).\n",
        "    seed: Seed sub-directory of the experiment.\n",
        "    idx: Index into the sorted entries of `renders/`; -1 picks the last one.\n",
        "\n",
        "  Returns:\n",
        "    The loaded images in `item_ids` order, or an empty list when any of the\n",
        "    expected render files does not exist.\n",
        "  \"\"\"\n",
        "  exp_dir = gpath.GPath(experiment_root, sweep_name, dataset_name, exp_name, f'{seed}')\n",
        "  print(f'Loading experiment images from {exp_dir}')\n",
        "  # NOTE(review): entries are sorted as paths, so picking the latest step\n",
        "  # presumably relies on zero-padded directory names -- confirm.\n",
        "  step_dirs = sorted((exp_dir / 'renders').iterdir())\n",
        "  step_dir = step_dirs[idx]\n",
        "  print(f'Experiment step = {int(step_dir.name)}')\n",
        "  val_dir = step_dir / 'val'\n",
        "  render_paths = [val_dir / f'rgb_{item_id}.png' for item_id in item_ids]\n",
        "  if not all(path.exists() for path in render_paths):\n",
        "    return []\n",
        "  return utils.parallel_map(load_image, render_paths, show_pbar=False)\n",
        "\n",
        "\n",
        "def compute_experiment_metrics(target_images, pred_images, eval_scale, crop=False):\n",
        "  \"\"\"Computes per-image PSNR, LPIPS and MS-SSIM for one experiment.\n",
        "\n",
        "  Args:\n",
        "    target_images: Ground-truth images.\n",
        "    pred_images: Predicted images, in the same order as `target_images`.\n",
        "    eval_scale: If > 1, images are downsampled by this factor before\n",
        "      evaluation; if < 1, they are upsampled by int(1 / eval_scale); if 1,\n",
        "      they are used as is.\n",
        "    crop: Unused; kept so existing callers keep working.\n",
        "\n",
        "  Returns:\n",
        "    A dict mapping 'psnr', 'lpips' and 'ssim' to lists with one float per\n",
        "    image pair.\n",
        "  \"\"\"\n",
        "  if eval_scale > 1:\n",
        "    def scale_fn(x):\n",
        "      return image_utils.downsample_image(\n",
        "          image_utils.make_divisible(x, eval_scale), eval_scale)\n",
        "    scale_message = f'Downsampling images by a factor of {eval_scale}'\n",
        "  elif eval_scale < 1:\n",
        "    def scale_fn(x):\n",
        "      return image_utils.upsample_image(\n",
        "          image_utils.make_divisible(x, eval_scale), int(1/eval_scale))\n",
        "    scale_message = f'Upsampling images by a factor of {int(1/eval_scale)}'\n",
        "  if eval_scale != 1:\n",
        "    print(scale_message)\n",
        "    target_images = utils.parallel_map(scale_fn, target_images)\n",
        "    pred_images = utils.parallel_map(scale_fn, pred_images)\n",
        "\n",
        "  metrics_dict = defaultdict(list)\n",
        "  for i, (target, pred) in enumerate(zip(ProgressIter(target_images), pred_images)):\n",
        "    if i == 0:\n",
        "      # Show the first pair as a visual sanity check.\n",
        "      mediapy.show_images([target, pred], titles=['target', 'pred'])\n",
        "    mse = compute_mse(target, pred)\n",
        "    psnr = mse_to_psnr(mse)\n",
        "    ssim = compute_ssim(target, pred)\n",
        "    lpips = compute_lpips(target, pred)\n",
        "    metrics_dict['psnr'].append(float(psnr))\n",
        "    metrics_dict['lpips'].append(float(lpips))\n",
        "    metrics_dict['ssim'].append(float(ssim))\n",
        "\n",
        "  return metrics_dict\n",
        "\n",
        "\n",
        "def save_experiment_images(save_dir, images, item_ids):\n",
        "  \"\"\"Writes `images` into `save_dir` as `<item_id>.png`, converted to uint8.\"\"\"\n",
        "  save_dir.mkdir(exist_ok=True, parents=True)\n",
        "\n",
        "  def save_one(path_and_image):\n",
        "    path, image = path_and_image\n",
        "    return image_utils.save_image(path, image_utils.image_to_uint8(image))\n",
        "\n",
        "  jobs = [(save_dir / f'{i}.png', image) for i, image in zip(item_ids, images)]\n",
        "  utils.parallel_map(save_one, jobs)\n",
        "\n",
        "\n",
        "def summarize_metrics(metrics_dict):\n",
        "  \"\"\"Averages each metric's list of values into a single number.\"\"\"\n",
        "  summary = {}\n",
        "  for metric_name, values in metrics_dict.items():\n",
        "    summary[metric_name] = np.mean(values)\n",
        "  return summary"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LRrDmkrmgyTc"
      },
      "source": [
        "## Compute NeRF metrics."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EF7zhxd9Pc8U"
      },
      "source": [
        "def load_nerf_images(dataset_dict, sweep_name, exp_name):\n",
        "  \"\"\"Loads the validation targets and the latest complete set of renders.\n",
        "\n",
        "  The most recent render directory is tried first; when it is missing,\n",
        "  unreadable or incomplete, up to three earlier ones are tried.\n",
        "\n",
        "  Args:\n",
        "    dataset_dict: Dict with the keys 'name', 'val_ids' and 'val_rgbs'.\n",
        "    sweep_name: Sweep the experiment belongs to.\n",
        "    exp_name: Name of the experiment.\n",
        "\n",
        "  Returns:\n",
        "    A tuple (target_images, pred_images).\n",
        "\n",
        "  Raises:\n",
        "    RuntimeError: If none of the tried render directories holds a render for\n",
        "      every validation image.\n",
        "  \"\"\"\n",
        "  dataset_name = dataset_dict['name']\n",
        "  print(f'Computing metrics for {dataset_name} / {sweep_name} / {exp_name}')\n",
        "  item_ids = dataset_dict['val_ids']\n",
        "  target_images = dataset_dict['val_rgbs']\n",
        "\n",
        "  pred_images = []\n",
        "  # Walk from the newest directory (idx=-1) back to idx=-4. Iterating over a\n",
        "  # range guarantees progress even when loading raises; retrying the same\n",
        "  # index after an error would never terminate.\n",
        "  for idx in range(-1, -5, -1):\n",
        "    try:\n",
        "      pred_images = load_experiment_images(\n",
        "          dataset_name, sweep_name, exp_name, item_ids, idx=idx)\n",
        "    except (FileNotFoundError, ValueError, IndexError):\n",
        "      # IndexError: fewer render directories exist than we try to step back.\n",
        "      print(f'Latest renders not ready, choosing previous')\n",
        "      pred_images = []\n",
        "    if len(pred_images) == len(target_images):\n",
        "      break\n",
        "\n",
        "  if len(pred_images) < len(target_images):\n",
        "    raise RuntimeError('Images are not ready.')\n",
        "\n",
        "  return target_images, pred_images\n",
        "\n",
        "\n",
        "# Set to True to reuse a previously written metrics.json instead of\n",
        "# recomputing. Off by default, so metrics are always recomputed.\n",
        "use_cached_metrics = False\n",
        "\n",
        "for dataset_name in dataset_names:\n",
        "  eval_scale = eval_scales[dataset_name]\n",
        "  print(f'{dataset_name}, eval_scale={eval_scale }')\n",
        "  dataset_dict = None\n",
        "  for sweep_name, exp_name in experiments:\n",
        "    if sweep_name == 'none':\n",
        "      continue\n",
        "    print(f'Processing {dataset_name} / {sweep_name} / {exp_name}')\n",
        "    cache_dir = work_dir / dataset_name / sweep_name / exp_name\n",
        "    cache_dir.mkdir(exist_ok=True, parents=True)\n",
        "    metrics_path = cache_dir / 'metrics.json'\n",
        "    # Load existing metrics if they exist.\n",
        "    if use_cached_metrics and metrics_path.exists():\n",
        "      print(f'Loading cached metrics from {metrics_path}')\n",
        "      with metrics_path.open('rt') as f:\n",
        "        metrics_dict = json.load(f)\n",
        "    else:\n",
        "      # Lazily load dataset dict.\n",
        "      if dataset_dict is None:\n",
        "        dataset_dict = load_dataset_dict(dataset_name)\n",
        "\n",
        "      # Compute metrics.\n",
        "      target_images, pred_images = load_nerf_images(\n",
        "          dataset_dict, sweep_name, exp_name)\n",
        "\n",
        "      # Keep copies of the evaluated images next to the metrics; the\n",
        "      # qualitative cells further down read them back from `cache_dir`.\n",
        "      save_experiment_images(cache_dir / 'target_images', target_images, dataset_dict['val_ids'])\n",
        "      save_experiment_images(cache_dir / 'pred_images', pred_images, dataset_dict['val_ids'])\n",
        "      metrics_dict = compute_experiment_metrics(target_images, pred_images, eval_scale=eval_scale)\n",
        "      with metrics_path.open('wt') as f:\n",
        "        json.dump(metrics_dict, f, indent=2)\n",
        "\n",
        "    print(summarize_metrics(metrics_dict))\n",
        "    print()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zNr9PEcq2X04"
      },
      "source": [
        "## Create table"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sleC2oNszaJO"
      },
      "source": [
        "\n",
        "# Experiments shown in the table. Keys are (sweep_name, exp_name) pairs;\n",
        "# the values end up as the row labels rendered by the template.\n",
        "table_experiments = {}\n",
        "\n",
        "# Datasets are split into groups. Each group's entries are merged into\n",
        "# `table_dataset_names`, followed by one extra 'Mean' entry per group.\n",
        "dataset_groups = [{}, {}]\n",
        "table_dataset_names = {}\n",
        "for i, group in enumerate(dataset_groups):\n",
        "  table_dataset_names.update(group)\n",
        "  table_dataset_names[f'mean{i}'] = 'Mean'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JpwlhSgLPBgZ"
      },
      "source": [
        "def load_metric(key):\n",
        "  \"\"\"Reads the cached metrics.json of a (dataset, sweep, experiment) key.\"\"\"\n",
        "  dataset_name, sweep_name, exp_name = key\n",
        "  metric_path = work_dir / dataset_name / sweep_name / exp_name / 'metrics.json'\n",
        "  with metric_path.open('rt') as f:\n",
        "    metrics = json.load(f)\n",
        "  return metrics\n",
        "\n",
        "\n",
        "# One (dataset, sweep, experiment) key per table cell to load.\n",
        "metric_keys = []\n",
        "for dataset_name in dataset_names:\n",
        "  for sweep_name, exp_name in table_experiments:\n",
        "    metric_keys.append((dataset_name, sweep_name, exp_name))\n",
        "\n",
        "# Load all cached metric files, then index them by their key.\n",
        "metrics_list = utils.parallel_map(load_metric, metric_keys, show_pbar=True)\n",
        "metrics_by_key = {k: v for k, v in zip(metric_keys, metrics_list)}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "a4vITSf3OlJp"
      },
      "source": [
        "# Create nested dict.\n",
        "# Map each experiment name to its table value, dropping the sweep name.\n",
        "exp_mappings = {e: v for (_, e), v in table_experiments.items()}\n",
        "\n",
        "# dataset name -> experiment name -> summarized (averaged) metrics.\n",
        "dataset_metrics_dict = collections.defaultdict(dict)\n",
        "for (dataset_name, sweep_name, exp_name), metric_dict in metrics_by_key.items():\n",
        "  dataset_metrics_dict[dataset_name][exp_name] = summarize_metrics(metric_dict)\n",
        "# Dataset metadata without images; the table template reads 'val_ids' from it.\n",
        "dataset_dicts = {name: load_dataset_dict(name, use_images=False) for name in dataset_names}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z-QMfuNp0o-q"
      },
      "source": [
        "# @title Table template\n",
        "import jinja2\n",
        "from jinja2 import Template\n",
        "\n",
        "# Jinja environment for the table template: block tags do not leave blank\n",
        "# lines or leading whitespace behind, and output is not HTML-escaped.\n",
        "env = jinja2.Environment(\n",
        "    trim_blocks=True, autoescape=False, lstrip_blocks=True)\n",
        "\n",
        "# Jinja template of the LaTeX results table. It is rendered with the\n",
        "# variables `dataset_names`, `metric_names`, `datasets`, `experiment_names`\n",
        "# and `dataset_metrics`. Within each dataset/metric column the three best\n",
        "# values are prefixed with \\\\tablefirst, \\\\tablesecond and \\\\tablethird;\n",
        "# 'lpips' is ranked ascending, every other metric descending.\n",
        "template = env.from_string(\"\"\"\n",
        "%%%% AUTOMATICALLY GENERATED, DO NOT EDIT.\n",
        "\n",
        "\\\\begin{tabular}{l|{% for _ in dataset_names %}|{% for _ in metric_names %}c{% endfor %}{% endfor %}}\n",
        "\n",
        "\\\\toprule\n",
        "% Table Header (datasets).\n",
        "{% for dataset_name in dataset_names %}\n",
        "& \\\\multicolumn{ {{metric_names|length}} }{c}{\n",
        "  \\\\makecell{\n",
        "  \\\\textsc{\\\\small {{dataset_names[dataset_name]}} }\n",
        "  {% if 'mean' not in dataset_name %}\n",
        "    \\\\\\\\({{ datasets[dataset_name]['val_ids']|length }} images)\n",
        "  {% endif %}\n",
        "  }\n",
        "}\n",
        "{% endfor %}\n",
        "\\\\\\\\\n",
        "\n",
        "% Table header (metrics).\n",
        "{% for dataset_name in dataset_names %}\n",
        "  {% for metric_name in metric_names.values() %}\n",
        "    & \\\\multicolumn{1}{c}{ \\\\footnotesize {{metric_name}} }\n",
        "  {% endfor %}\n",
        "{% endfor %}\n",
        "\\\\\\\\\n",
        "\\\\hline\n",
        "\n",
        "% Table contents.\n",
        "{% for exp_k, exp_name in experiment_names.items() %}\n",
        "  {% set exp_i = loop.index0 %}\n",
        "  {{exp_name}}\n",
        "  {% for dataset_key, dataset_name in dataset_names.items() %}\n",
        "    {%- for metric_key in metric_names -%}\n",
        "      {% set metrics = dataset_metrics[dataset_key][metric_key] %}\n",
        "      {% if metric_key != 'lpips'%}\n",
        "        {% set rank = (-metrics).argsort().argsort()[exp_i] %}\n",
        "      {% else %}\n",
        "        {% set rank = metrics.argsort().argsort()[exp_i] %}\n",
        "      {% endif %}\n",
        "      &\n",
        "      {%- if rank == 0 -%}\n",
        "        \\\\tablefirst\n",
        "      {%- elif rank == 1 -%}\n",
        "        \\\\tablesecond\n",
        "      {%- elif rank == 2 -%}\n",
        "        \\\\tablethird\n",
        "      {%- endif -%}\n",
        "      ${{\"{:#.03g}\".format(metrics[exp_i]).lstrip('0')}}$\n",
        "    {% endfor %}  \n",
        "  {% endfor %}\n",
        "  \\\\\\\\\n",
        "{% endfor %}\n",
        "\\\\bottomrule\n",
        "\n",
        "\\\\end{tabular}\n",
        "\n",
        "%%%% AUTOMATICALLY GENERATED, DO NOT EDIT.\n",
        "\"\"\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZYhb8QtBzg_4"
      },
      "source": [
        "from pprint import pprint\n",
        "# @title Generate table\n",
        "\n",
        "use_psnr = True # @param{type:'boolean'}\n",
        "use_ssim = False # @param{type:'boolean'}\n",
        "use_lpips = True # @param{type:'boolean'}\n",
        "\n",
        "\n",
        "# LaTeX column labels of the metrics shown in the table. Raw strings keep\n",
        "# the backslashes literal: in a plain string '\\\\u...' starts a unicode\n",
        "# escape and fails to parse.\n",
        "METRIC_NAMES = {\n",
        "    'psnr': r'PSNR$\\uparrow$',\n",
        "    'ssim': r'MS-SSIM$\\uparrow$',\n",
        "    'lpips': r'LPIPS$\\downarrow$',\n",
        "}\n",
        "# Drop the metrics that are switched off above.\n",
        "if not use_psnr:\n",
        "  del METRIC_NAMES['psnr']\n",
        "if not use_ssim:\n",
        "  del METRIC_NAMES['ssim']\n",
        "if not use_lpips:\n",
        "  del METRIC_NAMES['lpips']\n",
        "\n",
        "\n",
        "# dataset key (or 'mean<group>') -> metric key -> array with one value per\n",
        "# experiment, in `exp_mappings` order.\n",
        "table_metrics = {}\n",
        "for group_id, dataset_group in enumerate(dataset_groups):\n",
        "  print(group_id)\n",
        "  group_metrics = collections.defaultdict(dict)\n",
        "  for dataset_k in dataset_group:\n",
        "    for metric_k in ['psnr', 'ssim', 'lpips']:\n",
        "      metric_v = np.array([dataset_metrics_dict[dataset_k][exp_k][metric_k].mean() \n",
        "                          for exp_k in exp_mappings])\n",
        "      group_metrics[dataset_k][metric_k] = metric_v\n",
        "\n",
        "  # Per-experiment mean over the datasets of this group, for the enabled\n",
        "  # metrics only.\n",
        "  group_metrics[f'mean{group_id}'] = {\n",
        "      m: np.stack([group_metrics[dk][m] for dk in dataset_group], axis=0).mean(axis=0)\n",
        "      for m in METRIC_NAMES\n",
        "  }\n",
        "  print(group_metrics.keys())\n",
        "  for k, v in group_metrics.items():\n",
        "    table_metrics[k] = v\n",
        " \n",
        "# Render the LaTeX table; runs of four spaces (template indentation) are\n",
        "# removed from the output.\n",
        "table_str = template.render(\n",
        "    datasets=dataset_dicts,\n",
        "    dataset_names=table_dataset_names,\n",
        "    dataset_metrics=table_metrics,\n",
        "    experiment_names=exp_mappings,\n",
        "    metric_names=METRIC_NAMES\n",
        ").replace('    ','')\n",
        "print(table_str)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s2iQfJ_3V_xK"
      },
      "source": [
        "## Choose visualizations"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OJJXMT0aWIVV"
      },
      "source": [
        "# Show every `stride`-th validation image of one dataset, titled with its\n",
        "# id, to pick the frames used for the qualitative figures.\n",
        "stride = 1\n",
        "dataset_name = 'vrig/broom2'\n",
        "dataset_dict = load_dataset_dict(dataset_name, use_images=True)\n",
        "mediapy.show_images(\n",
        "    dataset_dict['val_rgbs'][::stride], \n",
        "    titles=dataset_dict['val_ids'][::stride])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UhIP-33QyKTO"
      },
      "source": [
        "# Print the id of one validation image of the dataset loaded above.\n",
        "# `dataset_dict` is the only dataset dict defined at this point.\n",
        "print(dataset_dict['val_ids'][10])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ySP7uEmFWphH"
      },
      "source": [
        "# Dataset name -> id of the validation image used for qualitative results.\n",
        "# The cells below swap 'left'/'right' inside the id to get the matching\n",
        "# training image id.\n",
        "dataset_qual_ids = {\n",
        "}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e_4ektTnXdb7"
      },
      "source": [
        "def load_dataset_image(dataset_name, item_id, scale=scale):\n",
        "  \"\"\"Loads `rgb/{scale}x/{item_id}.png` of a dataset under `dataset_root`.\"\"\"\n",
        "  image_file = dataset_root / dataset_name / f'rgb/{scale}x' / f'{item_id}.png'\n",
        "  return load_image(image_file)\n",
        "\n",
        "\n",
        "# Output directory for the qualitative figure images. It has to be a path\n",
        "# object: a plain string has neither `mkdir` nor the `/` operator.\n",
        "# NOTE(review): constructed without a path, like the other roots in this\n",
        "# notebook -- fill in before running.\n",
        "out_dir = gpath.GPath()\n",
        "out_dir.mkdir(exist_ok=True, parents=True)\n",
        "\n",
        "for dataset_name in dataset_names:\n",
        "  if dataset_name not in dataset_qual_ids:\n",
        "    continue\n",
        "  val_id = dataset_qual_ids[dataset_name]\n",
        "  # The training image id is the validation id with 'left'/'right' swapped.\n",
        "  if 'right' in val_id:\n",
        "    train_id = val_id.replace('right', 'left')\n",
        "  else:\n",
        "    train_id = val_id.replace('left', 'right')\n",
        "\n",
        "  ds = datasets.NerfiesDataSource(dataset_root / dataset_name, image_scale=4, camera_type='proto')\n",
        "  train_rgb = ds.load_rgb(train_id)\n",
        "  train_camera = ds.load_camera(train_id)\n",
        "  val_rgb = ds.load_rgb(val_id)\n",
        "  val_camera = ds.load_camera(val_id)\n",
        "\n",
        "  train_rgb = nv_crop(undistort_image(train_rgb, train_camera)[0])\n",
        "  val_rgb = nv_crop(undistort_image(val_rgb, val_camera)[0])\n",
        "\n",
        "  save_name = dataset_name.split('/')[-1]\n",
        "  mediapy.write_image(str(out_dir / f'{save_name}.train.jpg'), (train_rgb))\n",
        "  mediapy.write_image(str(out_dir / f'{save_name}.valid.jpg'), (val_rgb))\n",
        "  for sweep_name, exp_name in experiments:\n",
        "    cache_dir = work_dir / dataset_name / sweep_name / exp_name\n",
        "    print(cache_dir)\n",
        "    pred_rgb = load_image(cache_dir / 'pred_images' / f'{val_id}.png')\n",
        "\n",
        "    mediapy.show_images([train_rgb, val_rgb, pred_rgb])\n",
        "    mediapy.write_image(str(out_dir / f'{save_name}.{exp_name}.pred.jpg'), pred_rgb)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UliW_ikSKP5N"
      },
      "source": [
        "### Per-image PSNR / LPIPS numbers for the figures"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XWq5iIjlKK_F"
      },
      "source": [
        "# Experiment names to report; experiments not listed here are skipped.\n",
        "exp_names_to_show = [\n",
        "]\n",
        "\n",
        "# Dataset names to report; each needs an entry in `dataset_qual_ids`.\n",
        "datasets_to_show = [\n",
        "]\n",
        "\n",
        "\n",
        "def load_dataset_image(dataset_name, item_id, scale=scale):\n",
        "  \"\"\"Loads `rgb/{scale}x/{item_id}.png` of a dataset under `dataset_root`.\"\"\"\n",
        "  image_file = dataset_root / dataset_name / f'rgb/{scale}x' / f'{item_id}.png'\n",
        "  return load_image(image_file)\n",
        "\n",
        "\n",
        "# For every selected dataset and experiment, print the PSNR and LPIPS of\n",
        "# the chosen validation image and show target and prediction side by side.\n",
        "for dataset_name in datasets_to_show:\n",
        "  val_id = dataset_qual_ids[dataset_name]\n",
        "  # The training image id is the validation id with 'left'/'right' swapped.\n",
        "  if 'right' in val_id:\n",
        "    train_id = val_id.replace('right', 'left')\n",
        "  else:\n",
        "    train_id = val_id.replace('left', 'right')\n",
        "\n",
        "  train_rgb = load_dataset_image(dataset_name, train_id, 4)\n",
        "  print(train_rgb.shape)\n",
        "\n",
        "  for sweep_name, exp_name in experiments:\n",
        "    if exp_name not in exp_names_to_show:\n",
        "      print('Skipping', exp_name)\n",
        "      continue\n",
        "\n",
        "    cache_dir = work_dir / dataset_name / sweep_name / exp_name\n",
        "    print(cache_dir)\n",
        "    target_rgb = load_image(cache_dir / 'target_images' / f'{val_id}.png')\n",
        "    pred_rgb = load_image(cache_dir / 'pred_images' / f'{val_id}.png')\n",
        "    psnr = mse_to_psnr(compute_mse(target_rgb, pred_rgb))\n",
        "    lpips = compute_lpips(target_rgb, pred_rgb)\n",
        "    print(f'psnr: {float(psnr):#.03g}')\n",
        "    print(f'lpips: {float(lpips):#.03g}')\n",
        "\n",
        "    mediapy.show_images([target_rgb, pred_rgb])\n",
        "\n",
        "    print()\n",
        "    print()\n",
        "    print()"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}